{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "8db19215-aa60-4878-91ab-9fdcf454834f",
   "metadata": {},
   "source": [
    "# Fine-tuning a Text Classifier with PyTorch Lightning and AI Accelerators\n",
    "![Transformer block with an added classification head](./fine-tune.png)\n",
    "\n",
    "Image Source: [How to Fine-Tune Language Models: First Principles to Scalable Performance](https://pub.towardsai.net/how-to-fine-tune-language-models-first-principles-to-scalable-performance-78f42b02f112)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9b35182b-7c72-4edd-874b-87e252d9fd88",
   "metadata": {},
   "source": [
    "## 1. Preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "db4c5486-9081-477f-ac84-43eeaa3865cc",
   "metadata": {},
   "source": [
    "### Set a Random Seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "410134ee-3631-4442-91bb-1174c7a4702c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "42"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import lightning as L\n",
    "\n",
    "L.seed_everything(42, workers=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9729905b-38ae-4c02-b760-dacd998e7f12",
   "metadata": {},
   "source": [
    "### Read Data and Encode Labels Numerically"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac2bb7e6-e725-482c-9670-e1d93c4cc297",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from datasets import load_dataset\n",
    " \n",
    "# If you don't have a local data file to run this notebook, proceed as follows to download a dataset from Hugging Face:\n",
    "all_data = load_dataset(\"intel/polite-guard\", split=\"train\")\n",
    "all_data = all_data.to_pandas().sample(n=500)\n",
    "\n",
    "# Assuming that your data is in \"data.csv\" with columns \"label\" and \"text\", comment out the previous lines and uncomment the following line:\n",
    "# all_data = pd.read_csv(\"data.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fec4952c-baaf-4228-8e14-821afa12880a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'impolite': 0, 'neutral': 1, 'polite': 2, 'somewhat polite': 3}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Assign label mappings\n",
    "unique_labels = sorted(all_data[\"label\"].unique())\n",
    "num_labels = len(unique_labels)\n",
    "\n",
    "id2label = {k:v for k, v in enumerate(unique_labels)}\n",
    "label2id = {v:k for k, v in id2label.items()}\n",
    "\n",
    "all_data[\"label\"] = all_data[\"label\"].map(label2id)\n",
    "label2id"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47289997-1b21-4a6f-9bf9-e6ed805f18cb",
   "metadata": {},
   "source": [
    "### Do a Train | Validation | Test Split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b8241061-0cce-4878-9859-7c9207340170",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# First split: 80% train, 20% temp (to be split into val and test)\n",
    "train_df, temp = train_test_split(all_data, test_size=0.2, shuffle=True)\n",
    "\n",
    "# Second split: 50% val, 50% test from the 20%\n",
    "val_df, test_df = train_test_split(temp, test_size=0.5, shuffle=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9480b0a8-38da-4409-8602-3ed7698fc47a",
   "metadata": {},
   "source": [
    "## 2. Tokenization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e9ddcf60-3a32-4dba-97f0-840cbca36d0d",
   "metadata": {},
   "source": [
    "### Pick a Pretrained Transformer and Load its Tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9d5a26f9-7712-442d-b3c7-05a9995c786b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_ckpt = \"bert-base-uncased\" # Pretrained transformer model for fine-tuning\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6461373d-188a-453a-a6e8-b1630d29dbcd",
   "metadata": {},
   "source": [
    "### Tokenize and Create DataLoaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af693e52-ad78-4821-beac-7029cb916165",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "\n",
    "batch_size = 32\n",
    "num_workers = 7\n",
    "\n",
    "class TextDataset(Dataset):\n",
    "    def __init__(self, data, tokenizer, max_length=512):\n",
    "        self.data = data\n",
    "        self.tokenizer = tokenizer\n",
    "        self.max_length = max_length\n",
    "        self.is_dataframe = hasattr(data, \"iloc\")  # Check if the data is a dataframe\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.data)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        if self.is_dataframe:\n",
    "            row = self.data.iloc[idx]\n",
    "            text = row[\"text\"]\n",
    "            label = row[\"label\"]\n",
    "        else:\n",
    "            text = self.data[idx]\n",
    "            label = None  # No label if data is a list of strings\n",
    "\n",
    "        # Tokenize the input text\n",
    "        encoding = self.tokenizer(\n",
    "            text,\n",
    "            max_length=self.max_length,\n",
    "            padding=\"max_length\",\n",
    "            truncation=True,\n",
    "            return_tensors=\"pt\"\n",
    "        )\n",
    "\n",
    "        item = {\n",
    "            \"input_ids\": encoding[\"input_ids\"].squeeze(0),\n",
    "            \"attention_mask\": encoding[\"attention_mask\"].squeeze(0),\n",
    "        }\n",
    "\n",
    "        if label is not None:\n",
    "            item[\"label\"] = torch.tensor(label, dtype=torch.long)\n",
    "\n",
    "        return item\n",
    "\n",
    "\n",
    "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
    "\n",
    "# Create Dataset objects\n",
    "train_dataset = TextDataset(train_df, tokenizer)\n",
    "val_dataset = TextDataset(val_df, tokenizer)\n",
    "test_dataset = TextDataset(test_df, tokenizer)\n",
    "\n",
    "# Create DataLoaders\n",
    "train_loader = DataLoader(\n",
    "    train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers\n",
    ")\n",
    "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f004f0c-3448-4cfa-a492-ab4848bd2ac7",
   "metadata": {},
   "source": [
    "## 3. Fine-tuning"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "daa6b90e-e52f-46d9-9bcb-64a72c519df7",
   "metadata": {},
   "source": [
    "### Add a Classification Head to the Pretrained Transformer and Prepare the Model for Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "3ddd2905-496b-4cc9-9bbc-af24861a0f0d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "import torchmetrics\n",
    "from torchmetrics.classification import F1Score\n",
    "\n",
    "from transformers import (\n",
    "    AutoModelForSequenceClassification,\n",
    "    get_linear_schedule_with_warmup,\n",
    ")\n",
    "\n",
    "\n",
    "class LightningModel(L.LightningModule):\n",
    "    def __init__(self, model_name, num_labels, label2id, id2label, learning_rate=5e-5, weight_decay=0.01):\n",
    "        super().__init__()\n",
    "        self.save_hyperparameters()\n",
    "\n",
    "        # Initialize hyperparameters and model\n",
    "        self.learning_rate = learning_rate\n",
    "        self.weight_decay = weight_decay\n",
    "        self.id2label = id2label\n",
    "        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels, label2id=label2id, id2label=id2label)\n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "        \n",
    "        # Metrics\n",
    "        self.val_f1 = F1Score(num_classes=num_labels, task=\"multiclass\", average=\"weighted\")\n",
    "        self.test_f1 = F1Score(num_classes=num_labels, task=\"multiclass\", average=\"weighted\")\n",
    "\n",
    "        self.val_acc = torchmetrics.Accuracy(num_classes=num_labels, task=\"multiclass\")\n",
    "        self.test_acc = torchmetrics.Accuracy(num_classes=num_labels, task=\"multiclass\")\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, labels=None):\n",
    "        return self.model(input_ids, attention_mask=attention_mask, labels=labels)\n",
    "\n",
    "    def _shared_step(self, batch, stage):\n",
    "        \"\"\"\n",
    "        A single step function for training, validation, and testing.\n",
    "        \"\"\"\n",
    "        outputs = self(batch[\"input_ids\"], batch[\"attention_mask\"], labels=batch[\"label\"])\n",
    "        logits = outputs[\"logits\"]\n",
    "        loss = outputs[\"loss\"]\n",
    "        labels = batch[\"label\"]\n",
    "\n",
    "        # Update metrics\n",
    "        if stage == \"train\":\n",
    "            self.log(\"train_loss\", loss)\n",
    "            return loss\n",
    "        \n",
    "        if stage == \"val\":\n",
    "            self.val_acc(logits, labels)\n",
    "            self.val_f1(logits, labels)\n",
    "            self.log(\"val_acc\", self.val_acc, prog_bar=True)\n",
    "            self.log(\"val_f1\", self.val_f1, prog_bar=True)\n",
    "            \n",
    "        if stage == \"test\":\n",
    "            self.test_acc(logits, labels)\n",
    "            self.test_f1(logits, labels)\n",
    "            self.log(\"test_acc\", self.test_acc, prog_bar=True)\n",
    "            self.log(\"test_f1\", self.test_f1, prog_bar=True)\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        return self._shared_step(batch, \"train\")\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        self._shared_step(batch, \"val\")\n",
    "\n",
    "    def test_step(self, batch, batch_idx):\n",
    "        self._shared_step(batch, \"test\")\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)\n",
    "        \n",
    "        # Compute the number of training steps\n",
    "        num_training_steps = len(train_loader) * self.trainer.max_epochs\n",
    "        num_warmup_steps = int(0.1 * num_training_steps)  # 10% warm-up\n",
    "\n",
    "        lr_scheduler = {\n",
    "            \"scheduler\": get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps),\n",
    "            \"interval\": \"step\",  # Update every step\n",
    "            \"frequency\": 1\n",
    "        }\n",
    "        \n",
    "        return [optimizer], [lr_scheduler]\n",
    "\n",
    "    def predict_step(self, batch, batch_idx, dataloader_idx=0):\n",
    "        outputs = self(batch[\"input_ids\"], batch[\"attention_mask\"])\n",
    "        logits = outputs[\"logits\"]\n",
    "        predictions = torch.argmax(logits, dim=-1)\n",
    "        # Convert numeric predictions to labels using id2label\n",
    "        return [self.id2label[pred.item()] for pred in predictions]\n",
    "        \n",
    "    def save_to_hub(\n",
    "        self,\n",
    "        repo_name,\n",
    "        private=False,\n",
    "        model_commit_message=\"Add fine-tuned model\",\n",
    "        tokenizer_commit_message=\"Add tokenizer\",\n",
    "        token=None\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Push the model to the Hugging Face Hub\n",
    "        \"\"\"\n",
    "        self.model.push_to_hub(\n",
    "            repo_name,\n",
    "            private=private,\n",
    "            commit_message=model_commit_message,\n",
    "            token=token\n",
    "        )\n",
    "        self.tokenizer.push_to_hub(\n",
    "            repo_name,\n",
    "            private=private,\n",
    "            commit_message=tokenizer_commit_message,\n",
    "            token=token\n",
    "        )\n",
    "\n",
    "lightning_model = LightningModel(model_ckpt, num_labels=num_labels, label2id=label2id, id2label=id2label, learning_rate=5e-5, weight_decay=0.01) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38183147-5259-4f2b-96fd-ed246a9d078d",
   "metadata": {},
   "source": [
    "### Fit the Model to the Training Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "53cca578-511c-4e11-b717-299a3cd08499",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using bfloat16 Automatic Mixed Precision (AMP)\n",
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "\n",
      "  | Name     | Type                          | Params | Mode \n",
      "-------------------------------------------------------------------\n",
      "0 | model    | BertForSequenceClassification | 109 M  | eval \n",
      "1 | val_f1   | MulticlassF1Score             | 0      | train\n",
      "2 | test_f1  | MulticlassF1Score             | 0      | train\n",
      "3 | val_acc  | MulticlassAccuracy            | 0      | train\n",
      "4 | test_acc | MulticlassAccuracy            | 0      | train\n",
      "-------------------------------------------------------------------\n",
      "109 M     Trainable params\n",
      "0         Non-trainable params\n",
      "109 M     Total params\n",
      "437.941   Total estimated model params size (MB)\n",
      "4         Modules in train mode\n",
      "231       Modules in eval mode\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "95a7d3a9d49f497a90771a400b88b79e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c298a8accb5f4f02836e1649a8b58bb7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7824fd05da6746cb92b0db6efe5369c6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_f1 improved. New best score: 0.980\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "755e03c173444f84a1e57918b7304cc6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_f1 improved by 0.020 >= min_delta = 0.005. New best score: 1.000\n",
      "`Trainer.fit` stopped: `max_epochs=2` reached.\n"
     ]
    }
   ],
   "source": [
    "from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping\n",
    "from lightning.pytorch.loggers import TensorBoardLogger\n",
    "\n",
    "callbacks = [\n",
    "        ModelCheckpoint(save_top_k=1, mode=\"max\", monitor=\"val_f1\"),\n",
    "        EarlyStopping(monitor=\"val_f1\", patience=3, min_delta=0.005, mode=\"max\", verbose=True),\n",
    "    ]\n",
    "logger = TensorBoardLogger(save_dir=\"./logs\", name=\"Best-Validation-F1\")\n",
    "\n",
    "trainer = L.Trainer(\n",
    "        max_epochs=2,\n",
    "        callbacks=callbacks,\n",
    "        accelerator=\"auto\",\n",
    "        precision=\"bf16-mixed\",\n",
    "        devices=\"auto\",\n",
    "        logger=logger,\n",
    "        log_every_n_steps=10,\n",
    "    )\n",
    "\n",
    "trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39358d60-69ea-4f3d-b538-ea023250bdbf",
   "metadata": {},
   "source": [
    "### Test the Best Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0150f0f-08cc-4a7d-a879-691b5cf0f8e1",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Restoring states from the checkpoint path at ./logs/Best-Validation-F1/version_1/checkpoints/epoch=1-step=26.ckpt\n",
      "Loaded model weights from the checkpoint at ./logs/Best-Validation-F1/version_1/checkpoints/epoch=1-step=26.ckpt\n"
     ]
    }
   ],
   "source": [
    "trainer.test(lightning_model, test_loader, ckpt_path=\"best\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c153f7c0-7d02-4281-8d1d-6107ce37cd5e",
   "metadata": {},
   "outputs": [],
   "source": [
    "texts = [\n",
    "    \"I sincerely apologize for the inconvenience you've experienced. Please allow me a moment to resolve this for you as quickly as possible.\",\n",
    "    \"I understand this isn't ideal, but could we move forward with this solution?\",\n",
    "    \"The product specifications are as follows.\",\n",
    "    \"You must be new here; you clearly don't know what you're doing.\",\n",
    "]\n",
    "dataset = TextDataset(texts, tokenizer)\n",
    "dataloader = DataLoader(dataset)\n",
    "\n",
    "outputs = trainer.predict(lightning_model, dataloaders=dataloader)\n",
    "for text, output in zip(texts, outputs):\n",
    "    print(f'\"{text}\": {output[0]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1f5799f-92fa-4253-bab3-f54f0ad2e8eb",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
